import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import sys
from matplotlib.patches import Patch
import matplotlib.lines as mlines
sys.path.insert(1,'/REDACTED/fairness/code/config')
from configureColors import *
sys.path.insert(1,'/REDACTED/fairness/code/utilities')
import weightedBinscatter as wb
import estimators as et
#sys.path.append('packages')
#import binscatter
import matplotlib.font_manager as fm
import matplotlib
import statsmodels.formula.api as smf
from statsmodels.iolib.smpickle import load_pickle
import joblib
from trajectoryPlotters import *

# set plot defaults: apply the project style sheet, then register the bundled
# cmunrm.ttf font under the name 'latex' (inserted ahead of the other known fonts)
# and make it the default font family.
plt.style.use('/REDACTED/fairness/code/config/fairness.mplstyle')
fe = fm.FontEntry(fname='/REDACTED/fairness/code/config/fonts/cmunrm.ttf', name='latex')
fm.fontManager.ttflist.insert(0, fe)
matplotlib.rcParams['font.family'] = fe.name

# colour / alpha pairs from the project palettes: Black vs non-Black, and the
# 'eitc' vs 'non' groups
colorsBNB = cdictBlackNonBlack()
colorsENE = cdictEITCNon()
cb, ab = colorsBNB['Black']['color'], colorsBNB['Black']['alpha']
cnb, anb = colorsBNB['non-Black']['color'], colorsBNB['non-Black']['alpha']
ce, ae = colorsENE['eitc']['color'], colorsENE['eitc']['alpha']
cne, ane = colorsENE['non']['color'], colorsENE['non']['alpha']

## Linear Estimator
def regEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
    """Linear estimator of the Black / non-Black gap in `outcome`.

    Fits OLS of ``outcome ~ pbvarb`` on `dataset` with HC1 robust standard errors.

    Returns
    -------
    coef : slope on `pbvarb` (fitted difference between pbvarb = 1 and pbvarb = 0)
    se : HC1 standard error of that slope
    black_aud_rate : fitted value at pbvarb = 1 (intercept + slope)
    non_black_aud_rate : fitted value at pbvarb = 0 (intercept)
    """
    # smf (statsmodels.formula.api) is imported by this file; `sm` is not among the
    # visible imports. smf.ols is the same OLS.from_formula constructor.
    model = smf.ols(outcome + ' ~ ' + pbvarb, data=dataset).fit(cov_type='HC1')
    coef = model.params[pbvarb]
    se = model.bse[pbvarb]
    # Look parameters up by label: positional Series[int] indexing is deprecated in pandas 2.
    intercept = model.params['Intercept']
    black_aud_rate = intercept + coef
    non_black_aud_rate = intercept
    return coef, se, black_aud_rate, non_black_aud_rate

## Probabilistic Estimator 
def chenEstimate(dataset, pbvarb = 'predicted_prob_black', outcome = 'audited'):
    """Probability-weighted group rates of `outcome` and their difference.

    Each row contributes to the Black rate with weight p = dataset[pbvarb], and to the
    non-Black rate with weight 1 - p.

    Returns (black_rate - non_black_rate, black_rate, non_black_rate).
    """
    p_black = dataset[pbvarb]
    y = dataset[outcome]
    p_non_black = 1 - p_black
    black_aud_rate = (p_black * y).sum() / p_black.sum()
    non_black_aud_rate = (p_non_black * y).sum() / p_non_black.sum()
    return black_aud_rate - non_black_aud_rate, black_aud_rate, non_black_aud_rate

## Get standard errors
def getSEs(dataset, pbvarb='predicted_prob_black', outcome='audited', seReg=0.0):
    """Return the supplied regression SE and its rescaled version for the probabilistic estimator.

    `seReg` is provided by the caller (in this script: regEstimate's slope SE). It is
    multiplied by getSEMultiplier(dataset, pbvarb). `outcome` is accepted but unused.

    Returns (seReg, seReg * multiplier).
    """
    scaled_se = seReg * getSEMultiplier(dataset, pbvarb)
    return seReg, scaled_se

def getSEMultiplier(dataset,pbvarb='predicted_prob_black'):
    """Return sqrt(var(p) / (mean(p) * (1 - mean(p)))) for p = dataset[pbvarb].

    var(p) is pandas' default sample variance (ddof=1).
    """
    p = dataset[pbvarb]
    p_bar = p.mean()
    return np.sqrt(p.var() / (p_bar * (1 - p_bar)))

### OUTPUT

out = '/REDACTED/'

### OP AUDIT
cols = ['adj_gross_inc', 'predicted_prob_black', 'taxpayer_id_new', 'aud_no_research_audits']

dataBISG = pd.read_csv("/REDACTED/data/final/individualBISG2014_full_final.csv", usecols=cols)
len(dataBISG)
# keep only rows with a predicted Black probability; every estimator below uses it
dataBISG = dataBISG[~dataBISG.predicted_prob_black.isna()]
len(dataBISG)

#########################
#### Audit Rates by Income and Race
#########################
# subset to non-negative AGI (missing AGI also fails the >= 0 test and is dropped),
# then split into 20 quantile-based bins labelled 1..20
dataBISG.adj_gross_inc.isnull().sum() # same missingness as isEIC? 1,223,360
# .copy() so the bin column below is added to an independent frame, not a filtered view
# of dataBISG (avoids SettingWithCopyWarning / chained-assignment ambiguity)
dataBISG_noneg_agi = dataBISG[dataBISG['adj_gross_inc'] >= 0].copy()
dataBISG_noneg_agi["bin_20"] = pd.qcut(dataBISG_noneg_agi["adj_gross_inc"], q=20, labels=list(range(1, 20 + 1)))
# observed=False keeps every bin category (the current default) and silences the
# pandas FutureWarning about that default changing
bin_20_mean = dataBISG_noneg_agi.groupby("bin_20", observed=False)["adj_gross_inc"].mean()
bin_20_mean

dataBISG_noneg_agi.taxpayer_id_new.value_counts()
dataBISG_noneg_agi.taxpayer_id_new.isnull().sum()
dataBISG_noneg_agi.adj_gross_inc.isnull().sum()

# income and id are no longer needed once bins and bin means are computed
dataBISG_noneg_agi = dataBISG_noneg_agi.drop(columns=['adj_gross_inc', 'taxpayer_id_new'])

dataBISG_noneg_agi.head()
len(dataBISG_noneg_agi)
# ~~~~~~~~~~~~~~~

# bootstrap each bin n_iter_bootstrap times, calculating outcomes of interest on each iteration
# compile all results in full_lin_disp_df and full_prob_disp_df
from timeit import default_timer as timer

disp_cols_20 = ["audit_rate", "audit_rate_black", "audit_rate_non_black", "audit_rate_se", "bin_mean"]

# The per-bin subsets do not change between bootstrap iterations, so slice the data once
# here instead of re-filtering the full frame 20 times on every iteration.
n_bins_20 = len(dataBISG_noneg_agi.bin_20.value_counts())
bin_subsets_20 = {
    bin_label: dataBISG_noneg_agi[dataBISG_noneg_agi["bin_20"] == bin_label]
    for bin_label in range(1, n_bins_20 + 1)
}

n_iter_bootstrap = 100
lin_frames_20 = []
prob_frames_20 = []
for i in range(n_iter_bootstrap):
    start = timer()
    print("Bootstrap iteration: "+str(i))
    lin_rows = []
    prob_rows = []
    for bin_label, temp_20 in bin_subsets_20.items():
        if len(temp_20) == 0:
            continue
        # resample the bin with replacement; seeding with the iteration index makes each draw reproducible
        temp_20_b = temp_20.sample(frac=1, random_state=i, replace=True)
        lin_aud, lin_se, lin_black, lin_non_black = regEstimate(temp_20_b, pbvarb="predicted_prob_black", outcome="aud_no_research_audits")
        prob_aud, prob_black, prob_non_black = chenEstimate(temp_20_b, pbvarb="predicted_prob_black", outcome="aud_no_research_audits")
        # compute the SE multiplier on the same resample that produced lin_se, so both
        # factors of the probabilistic SE come from one sample
        prob_se = getSEs(temp_20_b, pbvarb="predicted_prob_black", outcome="aud_no_research_audits", seReg=lin_se)[1]
        # look the bin mean up by label so a skipped (empty) bin cannot shift later bins' means
        bin_mean = bin_20_mean.loc[bin_label]
        lin_rows.append((lin_aud, lin_black, lin_non_black, lin_se, bin_mean))
        prob_rows.append((prob_aud, prob_black, prob_non_black, prob_se, bin_mean))
    lin_disp_df = pd.DataFrame(lin_rows, columns=disp_cols_20)
    prob_disp_df = pd.DataFrame(prob_rows, columns=disp_cols_20)
    lin_disp_df['iteration'] = i
    prob_disp_df['iteration'] = i
    lin_frames_20.append(lin_disp_df)
    prob_frames_20.append(prob_disp_df)
    if i in [10, 50, 75]:
        # checkpoint the draws collected so far
        pd.concat(lin_frames_20).to_csv('/REDACTED/fig_4_bootstrap_lin_progress_V2.csv', index=False)
        pd.concat(prob_frames_20).to_csv('/REDACTED/fig_4_bootstrap_prob_progress_V2.csv', index=False)
    end = timer()
    print("Time: "+str(end-start))

# Concatenate once at the end: concatenating inside the loop is quadratic, and starting
# from an empty frame triggers a pandas FutureWarning. Fall back to an empty frame with
# the expected columns when no iterations ran.
empty_disp_df_20 = pd.DataFrame(columns=disp_cols_20 + ["iteration"])
full_lin_disp_df = pd.concat(lin_frames_20) if lin_frames_20 else empty_disp_df_20.copy()
full_prob_disp_df = pd.concat(prob_frames_20) if prob_frames_20 else empty_disp_df_20.copy()

full_prob_disp_df.to_csv('/REDACTED/full_prob_disp_df.csv', index=False)
full_lin_disp_df.to_csv('/REDACTED/full_lin_disp_df.csv', index=False)

# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# one run without bootstraps, for point estimates

point_cols_20 = ["audit_rate", "audit_rate_black", "audit_rate_non_black", "audit_rate_se", "bin_mean"]
lin_rows_20 = []
prob_rows_20 = []
for bin_label in range(1, len(dataBISG_noneg_agi.bin_20.value_counts()) + 1):
    temp_20 = dataBISG_noneg_agi[dataBISG_noneg_agi["bin_20"] == bin_label]
    if len(temp_20) == 0:
        continue
    lin_aud, lin_se, lin_black, lin_non_black = regEstimate(temp_20, pbvarb="predicted_prob_black", outcome="aud_no_research_audits")
    prob_aud, prob_black, prob_non_black = chenEstimate(temp_20, pbvarb="predicted_prob_black", outcome="aud_no_research_audits")
    prob_se = getSEs(temp_20, pbvarb="predicted_prob_black", outcome="aud_no_research_audits", seReg=lin_se)[1]
    # look the bin mean up by label so a skipped (empty) bin cannot shift later bins' means
    bin_mean = bin_20_mean.loc[bin_label]
    lin_rows_20.append((lin_aud, lin_black, lin_non_black, lin_se, bin_mean))
    prob_rows_20.append((prob_aud, prob_black, prob_non_black, prob_se, bin_mean))
lin_disp_df = pd.DataFrame(lin_rows_20, columns=point_cols_20)
prob_disp_df = pd.DataFrame(prob_rows_20, columns=point_cols_20)


# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

########################
## Figure 4 (probabilistic)
########################

# clean prob_disp_df, get analytical standard errors
# .copy() so the scaled / derived columns below go into an independent frame, not a filtered view
prob_disp_df = prob_disp_df.loc[prob_disp_df['bin_mean'] <= 200000].copy()
prob_disp_df['audit_rate_black'] = prob_disp_df['audit_rate_black'] * 100
prob_disp_df['audit_rate_non_black'] = prob_disp_df['audit_rate_non_black'] * 100
prob_disp_df['audit_rate_se'] = prob_disp_df['audit_rate_se'] * 100
# NOTE(review): audit_rate_se is regEstimate's slope SE (the SE of the Black - non-Black
# difference) rescaled by getSEMultiplier, not the SE of either group's own rate. It is
# applied to both groups' rates here, which likely makes these analytical bands wider
# than per-group intervals. Confirm this is the intended comparison.
for grp in ['black', 'non_black']:
    prob_disp_df['lb_' + grp] = prob_disp_df['audit_rate_' + grp] - 1.96 * prob_disp_df['audit_rate_se']
    prob_disp_df['ub_' + grp] = prob_disp_df['audit_rate_' + grp] + 1.96 * prob_disp_df['audit_rate_se']

### read in bootstrap estimates
full_prob_disp_df = pd.read_csv('/REDACTED/full_prob_disp_df.csv')
full_lin_disp_df = pd.read_csv('/REDACTED/full_lin_disp_df.csv')
full_prob_disp_df_f = full_prob_disp_df[full_prob_disp_df['bin_mean'] <= 200000]
len(full_prob_disp_df)
len(full_prob_disp_df_f)

# 2.5th / 97.5th percentiles of the bootstrap draws in each bin. 'midpoint' interpolation
# averages the two order statistics around each cut. For 100 draws that is exactly the
# mean of the 3rd/4th and 97th/98th sorted values, and it stays correct for any number
# of draws. sort=False keeps bins in first-appearance order, which matches the
# row order of prob_disp_df (both are built bin by bin), so results are assigned positionally.
prob_boot_by_bin = full_prob_disp_df_f.groupby('bin_mean', sort=False)
if prob_boot_by_bin.ngroups != len(prob_disp_df):
    raise ValueError("bootstrap file has %d bins but point estimates have %d"
                     % (prob_boot_by_bin.ngroups, len(prob_disp_df)))
for grp in ['black', 'non_black']:
    draws = prob_boot_by_bin['audit_rate_' + grp]
    prob_disp_df['lb_' + grp + '_bootstrap'] = 100 * draws.quantile(0.025, interpolation='midpoint').to_numpy()
    prob_disp_df['ub_' + grp + '_bootstrap'] = 100 * draws.quantile(0.975, interpolation='midpoint').to_numpy()

# plot point estimates with bootstrapped 95% confidence bands
ax = plt.subplot(111)
for grp, label, line_alpha in [('black', "Black", 1), ('non_black', "Non-Black", 0.4)]:
    ax.plot(prob_disp_df.bin_mean, prob_disp_df['audit_rate_' + grp], '-o', color='purple', alpha=line_alpha, markersize=2, label=label)
    ax.fill_between(prob_disp_df.bin_mean, prob_disp_df['lb_' + grp + '_bootstrap'], prob_disp_df['ub_' + grp + '_bootstrap'], color='purple', alpha=0.2)
ax.set_ylabel('Audit Rate (%)')
ax.set_xlabel('Reported Income ($)')
handles, labels = ax.get_legend_handles_labels()
handles.append(Patch(facecolor="purple", alpha=0.2, label='95% CI'))
ax.legend(handles=handles)
plt.savefig('/REDACTED/prob_disp_inc_plot_20_bootstrap.png')
plt.close()

### compare width of bounds
prob_disp_df['analytical_width_black'] = prob_disp_df['ub_black'] - prob_disp_df['lb_black']
prob_disp_df['bootstrap_width_black'] = prob_disp_df['ub_black_bootstrap'] - prob_disp_df['lb_black_bootstrap']
prob_disp_df[['analytical_width_black', 'bootstrap_width_black']]  ## bootstrap CI width smaller in every bin

prob_disp_df['analytical_width_non_black'] = prob_disp_df['ub_non_black'] - prob_disp_df['lb_non_black']
prob_disp_df['bootstrap_width_non_black'] = prob_disp_df['ub_non_black_bootstrap'] - prob_disp_df['lb_non_black_bootstrap']
prob_disp_df[['analytical_width_non_black', 'bootstrap_width_non_black']] ## bootstrap CI width smaller in every bin


########################
## Figure A.2 (linear)
########################

# clean lin_disp_df, get analytical standard errors
# .copy() so the scaled / derived columns below go into an independent frame, not a filtered view
lin_disp_df = lin_disp_df.loc[lin_disp_df['bin_mean'] <= 200000].copy()
lin_disp_df['audit_rate_black'] = lin_disp_df['audit_rate_black'] * 100
lin_disp_df['audit_rate_non_black'] = lin_disp_df['audit_rate_non_black'] * 100
lin_disp_df['audit_rate_se'] = lin_disp_df['audit_rate_se'] * 100
# NOTE(review): audit_rate_se is regEstimate's slope SE, i.e. the SE of the Black -
# non-Black difference, not of either group's own rate. It is applied to both groups'
# rates here, which likely widens these analytical bands. Confirm this is intended.
for grp in ['black', 'non_black']:
    lin_disp_df['lb_' + grp] = lin_disp_df['audit_rate_' + grp] - 1.96 * lin_disp_df['audit_rate_se']
    lin_disp_df['ub_' + grp] = lin_disp_df['audit_rate_' + grp] + 1.96 * lin_disp_df['audit_rate_se']

full_lin_disp_df_f = full_lin_disp_df[full_lin_disp_df['bin_mean'] <= 200000]
len(full_lin_disp_df)
len(full_lin_disp_df_f)

# 2.5th / 97.5th percentiles of the bootstrap draws in each bin. 'midpoint' interpolation
# equals the mean of the 3rd/4th and 97th/98th sorted values for 100 draws, and stays
# correct for any number of draws. sort=False keeps bins in first-appearance order,
# matching the row order of lin_disp_df, so results are assigned positionally.
lin_boot_by_bin = full_lin_disp_df_f.groupby('bin_mean', sort=False)
if lin_boot_by_bin.ngroups != len(lin_disp_df):
    raise ValueError("bootstrap file has %d bins but point estimates have %d"
                     % (lin_boot_by_bin.ngroups, len(lin_disp_df)))
for grp in ['black', 'non_black']:
    draws = lin_boot_by_bin['audit_rate_' + grp]
    lin_disp_df['lb_' + grp + '_bootstrap'] = 100 * draws.quantile(0.025, interpolation='midpoint').to_numpy()
    lin_disp_df['ub_' + grp + '_bootstrap'] = 100 * draws.quantile(0.975, interpolation='midpoint').to_numpy()

# plot point estimates with bootstrapped 95% confidence bands
ax = plt.subplot(111)
for grp, label, line_alpha in [('black', "Black", 1), ('non_black', "Non-Black", 0.4)]:
    ax.plot(lin_disp_df.bin_mean, lin_disp_df['audit_rate_' + grp], '-o', color='purple', alpha=line_alpha, markersize=2, label=label)
    ax.fill_between(lin_disp_df.bin_mean, lin_disp_df['lb_' + grp + '_bootstrap'], lin_disp_df['ub_' + grp + '_bootstrap'], color='purple', alpha=0.2)
ax.set_ylabel('Audit Rate (%)')
ax.set_xlabel('Reported Income ($)')
handles, labels = ax.get_legend_handles_labels()
handles.append(Patch(facecolor="purple", alpha=0.2, label='95% CI'))
ax.legend(handles=handles)
plt.savefig('/REDACTED/lin_disp_inc_plot_20_bootstrap.png')
plt.close()

### compare width of bounds
lin_disp_df['analytical_width_black'] = lin_disp_df['ub_black'] - lin_disp_df['lb_black']
lin_disp_df['bootstrap_width_black'] = lin_disp_df['ub_black_bootstrap'] - lin_disp_df['lb_black_bootstrap']
lin_disp_df[['analytical_width_black', 'bootstrap_width_black']]  ## bootstrap CI width smaller in most bins
# flag bins where the bootstrap band is wider than the analytical one
lin_disp_df['big_boot'] = np.where(lin_disp_df.analytical_width_black < lin_disp_df.bootstrap_width_black, 1, 0)
lin_disp_df[['analytical_width_black', 'bootstrap_width_black', 'big_boot']]

lin_disp_df['analytical_width_non_black'] = lin_disp_df['ub_non_black'] - lin_disp_df['lb_non_black']
lin_disp_df['bootstrap_width_non_black'] = lin_disp_df['ub_non_black_bootstrap'] - lin_disp_df['lb_non_black_bootstrap']
lin_disp_df[['analytical_width_non_black', 'bootstrap_width_non_black']] ## bootstrap CI width smaller in every bin